#####
# Script to generate figure 4.
#####

library(ggplot2)
library(cowplot)
library(patchwork)
library(viridis)
library(tidyverse)
library(pracma)
library(nleqslv)
library(rootSolve)
library(fields)
library(doParallel)
library(scales)
library(data.table)

source("equations.r")
source("simulation_functions.r")
source("planner_simulation.r")

# Model parameters ------------------------------------------------------------
# NOTE(review): `F` and `T` shadow the base shorthands for FALSE/TRUE from here
# on. They are kept because `params` carries columns named F and T, and the
# sourced helper files may read these globals directly.
avg_sat_decay <- 1 # 1: satellites are infinitely lived
p <- 1
F <- 35
discount <- 0.99
r <- (1 - discount) / discount # the rate r for which discount == 1 / (1 + r)
fe_eqm <- p / F - r # the equilibrium risk
d <- 0.45
m <- 2
aSS <- 0.005
aSD <- 0.005
aDD <- 0.025
bSS <- 10
bSD <- 5.375
bDD <- 0.9
T <- 150

# One-row parameter frame handed to the model helpers.
params <- data.frame(
  p = p, F = F, discount = discount, r = r, fe_eqm = fe_eqm,
  d = d, m = m,
  aSS = aSS, aSD = aSD, aDD = aDD,
  bSS = bSS, bSD = bSD, bDD = bDD,
  T = T
)

# Same parameters, but with the equilibrium risk doubled (used for oa_coll_rate_2).
params2 <- transform(params, fe_eqm = fe_eqm * 2)

# Satellite-debris evaluation grid: 250 x 250 points covering [0, 10] on each
# axis (model units). expand.grid varies `satellites` fastest.
S_element <- seq(from = 0, to = 10, length.out = 250)
D_element <- seq(from = 0, to = 10, length.out = 250)
SD_grid <- setNames(expand.grid(S_element, D_element), c("satellites", "debris"))

# One-step-jump surface ---------------------------------------------------------
# oass: the steady state ss_finder() returns when started from the origin.
oass <- ss_finder(c(0, 0), params)

# The ray ends at oass and starts on the second-coordinate = 0 axis. Along the
# ray the second coordinate rises by m per unit of the first coordinate.
oass_ray_start <- oass - c(oass[2] / m, oass[2])

# A single point a short step (0.2 along the first coordinate) up the ray,
# mapped back through physics_inverter() to use as an initial condition later.
ray_pt_to_trace <- oass_ray_start + c(0.2, m * 0.2)
ray_pt_origin <- physics_inverter(ray_pt_to_trace, params)

# Evenly spaced points along the ray, as many as there are grid steps per axis.
oass_ray <- cbind(
  S = seq(oass_ray_start[1], oass[1], length.out = length(S_element)),
  D = seq(oass_ray_start[2], oass[2], length.out = length(S_element))
)

# Invert every ray point (one row per point). physics_inverter() is assumed to
# return a length-2 numeric vector; vapply enforces this.
oass_ray_origin_surface <- as.data.frame(t(vapply(
  seq_len(nrow(oass_ray)),
  function(row) physics_inverter(oass_ray[row, ], params),
  numeric(2)
)))
colnames(oass_ray_origin_surface) <- c("sats", "debs")

# Calibration scale factors: multiply model-unit quantities by these to get
# calibrated units (satellite-like quantities by s_scale_factor, debris by
# d_scale_factor).
# NOTE(review): each factor looks like (calibrated value) / (model value) for a
# single reference point. The provenance of the constants is not recorded here.
s_scale_factor <- 68289.84 / 3.728625 # satellites
d_scale_factor <- 367372.3 / 4.214421 # debris

# Evaluate the open-access model at every grid point (model units). mutate()
# evaluates its arguments in order, so later columns can use earlier ones.
# NOTE(review): oa_coll_rate / oa_coll_rate_2 hold eqm_sat_stock() values and are
# drawn on the Satellites axis of the phase plots, so the names look stale.
fig5_dfrm <- data.frame(SD_grid) %>%
  mutate(
    oa_launches = eqm_launch_rate(satellites, debris, params),
    oa_S_ = S_(oa_launches, satellites, debris, params),
    oa_D_ = D_(oa_launches, satellites, debris, params),
    numerical_coll_rate = L(oa_S_, oa_D_, params),
    oa_coll_rate = eqm_sat_stock(debris, params),
    oa_coll_rate_2 = eqm_sat_stock(debris, params2),
    sat_null = sats_SS(satellites, debris, params),
    deb_null = debs_SS(satellites, debris, params)
  )

# Phase-plane launch paths from gen_launch_path() (defined in a sourced helper).
# Visible call shape:
#   - init: the starting point. Presumably c(satellites, debris), matching the
#     OA.ts(S, D, ...) calls below -- TODO confirm.
#   - 150: equals T above, presumably the horizon -- confirm.
#   - a third value: 0.25 / 0.05 for the clp_* paths and 0 otherwise. Its meaning
#     is not visible here. NOTE(review): confirm against gen_launch_path.
#   - params, then the grid frame fig5_dfrm.
# NOTE(review): clp_d1, clp_d2 and lp_s1 are attached to fig5_dfrm below, but none
# of the plots in this script draw them.
# Paths starting at c(0, 0):
clp_d1 <- gen_launch_path(init=c(0,0),150,0.25,params,fig5_dfrm)
clp_d2 <- gen_launch_path(init=c(0,0),150,0.05,params,fig5_dfrm)
lp_d1 <- gen_launch_path(init=c(0,0),150,0,params,fig5_dfrm)

# Paths starting at c(2.7, 0) and c(3.5, 0):
lp_s1 <- gen_launch_path(init=c(2.7,0),150,0,params,fig5_dfrm)
lp_s2 <- gen_launch_path(init=c(3.5,0),150,0,params,fig5_dfrm)

# Path starting from ray_pt_origin, the inverted point just up the oass ray:
lp_sj1 <- gen_launch_path(init=ray_pt_origin,150,0,params,fig5_dfrm)

# Open-access time paths from OA.ts() (sourced helper), one per starting point:
#   cd1 from (0, 0), s2 from (3.5, 0), sj1 from ray_pt_origin.
# All three use the same remaining arguments (150, 0, opt = 0, constraint = 0).
# The satellites/debris columns are then resized to nrow(fig5_dfrm) with
# `length<-`, which NA-pads.
# NOTE(review): the resized *_sats / *_debs vectors are not referenced by the
# plotting code in this script. The time-series panels use *_launch_path.
cd1_launch_path <- OA.ts(0, 0, 150, 0, opt = 0, constraint = 0, params = params)
cd1_sats <- `length<-`(as.data.frame(cd1_launch_path)$satellites, nrow(fig5_dfrm))
cd1_debs <- `length<-`(as.data.frame(cd1_launch_path)$debris, nrow(fig5_dfrm))

s2_launch_path <- OA.ts(3.5, 0, 150, 0, opt = 0, constraint = 0, params = params)
s2_sats <- `length<-`(as.data.frame(s2_launch_path)$satellites, nrow(fig5_dfrm))
s2_debs <- `length<-`(as.data.frame(s2_launch_path)$debris, nrow(fig5_dfrm))

sj1_launch_path <- OA.ts(ray_pt_origin[1], ray_pt_origin[2], 150, 0,
                         opt = 0, constraint = 0, params = params)
sj1_sats <- `length<-`(as.data.frame(sj1_launch_path)$satellites, nrow(fig5_dfrm))
sj1_debs <- `length<-`(as.data.frame(sj1_launch_path)$debris, nrow(fig5_dfrm))

# Rescale one OA.ts() trajectory into calibrated units: launches and satellites
# are multiplied by s_scale_factor and debris by d_scale_factor. Other columns
# (e.g. time, risk) are returned unchanged. Returns a plain data.frame.
# Uses as_tibble(); the previous as.tibble() is deprecated in tibble.
scale_oa_path <- function(path) {
  as.data.frame(
    as_tibble(path) %>%
      mutate(
        launches = launches * s_scale_factor,
        satellites = satellites * s_scale_factor,
        debris = debris * d_scale_factor
      )
  )
}

# Join the three trajectories on time.
# - The first join suffixes the cd1 and s2 columns (launches_cd1, launches_s2, ...).
# - The sj1 columns do not collide with those, so they keep their bare names
#   (launches, risk, satellites, debris).
# The time-series panels rely on this naming.
fig5_time_dfrm <- left_join(
  scale_oa_path(cd1_launch_path),
  scale_oa_path(s2_launch_path),
  by = c("time"), suffix = c("_cd1", "_s2")
)
fig5_time_dfrm <- left_join(
  fig5_time_dfrm,
  scale_oa_path(sj1_launch_path),
  by = c("time"), suffix = c("", "_sj1")
)

# Rescale the phase-plane geometry into calibrated units.
oass_ray_origin_surface_scaled <- mutate(
  oass_ray_origin_surface,
  sats = sats * s_scale_factor,
  debs = debs * d_scale_factor
)

# Multiply a launch path's lp_sats by s_scale_factor and its lp_debs by
# d_scale_factor. All other columns pass through unchanged.
rescale_launch_path <- function(path) {
  mutate(path,
         lp_sats = lp_sats * s_scale_factor,
         lp_debs = lp_debs * d_scale_factor)
}

clp_d1_scaled <- rescale_launch_path(clp_d1)
clp_d2_scaled <- rescale_launch_path(clp_d2)
lp_d1_scaled <- rescale_launch_path(lp_d1)
lp_s1_scaled <- rescale_launch_path(lp_s1)
lp_s2_scaled <- rescale_launch_path(lp_s2)
lp_sj1_scaled <- rescale_launch_path(lp_sj1)

# Attach the scaled surface and launch paths as prefixed columns
# (oass_surf.sats, d1.lp_debs, ...).
# - cbind() on data frames goes through data.frame(), which recycles shorter
#   frames whose row counts divide the grid's: oass_surf has 250 rows against
#   62,500 grid rows.
# - The launch paths are presumably already sized to the grid -- TODO confirm.
fig5_dfrm <- cbind(
  fig5_dfrm,
  oass_surf = oass_ray_origin_surface_scaled,
  cd1 = clp_d1_scaled,
  cd2 = clp_d2_scaled,
  d1 = lp_d1_scaled,
  s1 = lp_s1_scaled,
  s2 = lp_s2_scaled,
  sj1 = lp_sj1_scaled
)

head(fig5_dfrm)

# Convert the grid's own model-unit columns to calibrated units: satellite-type
# quantities by s_scale_factor, debris-type quantities by d_scale_factor.
fig5_dfrm <- fig5_dfrm %>%
  mutate(
    across(
      c(oa_launches, satellites, oa_S_, sat_null, oa_coll_rate, oa_coll_rate_2),
      function(v) v * s_scale_factor
    ),
    across(
      c(debris, oa_D_, deb_null),
      function(v) v * d_scale_factor
    )
  )

head(fig5_dfrm)

# Express the steady state in calibrated units: first coordinate (satellites) by
# s_scale_factor, second (debris) by d_scale_factor. From here on oass is
# already scaled.
oass <- oass * c(s_scale_factor, d_scale_factor)

# Row of fig5_dfrm holding the start of the dotted one-step-jump ray: among the
# surface points whose debris coordinate is exactly 0, the one with the most
# satellites.
# - The index must refer to fig5_dfrm itself, because the plots index
#   fig5_dfrm$oass_surf.* with it. Calling which.max() on a filtered subset would
#   return a position within the subset instead.
# - The exact comparison relies on physics_inverter() returning debris == 0 for
#   some ray points. The stop() below reports clearly when it does not.
on_sat_axis <- which(fig5_dfrm$oass_surf.debs == 0)
if (length(on_sat_axis) == 0L) {
  stop("No one-step-jump surface point has debris exactly 0; ",
       "cannot locate the start of the dotted ray.", call. = FALSE)
}
ray_start_idx <- on_sat_axis[which.max(fig5_dfrm$oass_surf.sats[on_sat_axis])]

# Figure 5a, which the ggsave() call at the end of this script does not use. It
# shows the oa_coll_rate curve against debris, three launch paths (1: d1, 2: s2,
# 3: sj1), the dotted one-step-jump ray and the open-access steady state.
# Notes:
# - oass is already in calibrated units at this point, so it is not rescaled
#   again. Rescaling it pushed the point far outside the limits and dropped it.
# - Constant-position layers use annotate(), so each is drawn once rather than
#   once per data row.
# - Axis lines end at Inf, the panel edge; scale limits do not censor infinite
#   values.
# - Limits are set only via scale_*_continuous; the extra xlim()/ylim() calls
#   were redundant.
figure_5a <- ggplot(data = fig5_dfrm[which(fig5_dfrm$oa_coll_rate >= 0), ]) +
  geom_line(aes(x = debris, y = oa_coll_rate), size = 1) +
  annotate("segment", x = 0, y = 0, xend = Inf, yend = 0, size = 0.15) +
  annotate("segment", x = 0, y = 0, xend = 0, yend = Inf, size = 0.15) +
  annotate("segment",
           x = fig5_dfrm$oass_surf.debs[ray_start_idx],
           y = fig5_dfrm$oass_surf.sats[ray_start_idx],
           xend = oass[2], yend = oass[1],
           size = 1, linetype = "dotted") +
  geom_path(aes(x = d1.lp_debs, y = d1.lp_sats),
            linetype = "dashed", color = viridis(9)[7], size = 1.25) +
  geom_path(aes(x = s2.lp_debs, y = s2.lp_sats),
            linetype = "dashed", color = viridis(9)[3], size = 1.25) +
  geom_path(aes(x = sj1.lp_debs, y = sj1.lp_sats),
            linetype = "dashed", color = viridis(9)[5], size = 1.25) +
  annotate(geom = "text", x = fig5_dfrm$sj1.lp_debs[1], y = fig5_dfrm$sj1.lp_sats[1],
           label = "3", hjust = 0, vjust = 1.5, size = 7, color = viridis(9)[5]) +
  annotate(geom = "text", x = fig5_dfrm$s2.lp_debs[3], y = fig5_dfrm$s2.lp_sats[3],
           label = "2", hjust = -1, vjust = 1.5, size = 7, color = viridis(9)[3]) +
  annotate(geom = "text", x = fig5_dfrm$d1.lp_debs[7], y = fig5_dfrm$d1.lp_sats[7],
           label = "1", hjust = 0, vjust = 1.5, size = 7, color = viridis(9)[7]) +
  annotate("point", x = oass[2], y = oass[1], size = 5) +
  ylab("Satellites") + xlab("Debris") +
  scale_x_continuous(limits = c(0, 4 * d_scale_factor)) +
  scale_y_continuous(limits = c(0, 4 * s_scale_factor)) +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15),
        panel.grid.minor = element_blank())


# Figure 5b, which the ggsave() call at the end of this script does not use. It
# is the phase plane with:
# - the oa_coll_rate curve;
# - the zero contours of sat_null (labelled Sn) and deb_null (labelled Dn);
# - the one-step-jump surface points;
# - the open-access steady state;
# - short arrows at four corners.
# Fixes relative to the previous version:
# - The text labels were placed at unscaled model-unit positions, which piled
#   them up at the origin. They are now multiplied by the calibration factors
#   to land on the scaled axes.
# - oass is already in calibrated units, so it is no longer rescaled; that had
#   pushed the point outside the limits and dropped it.
# - Constant-position layers use annotate(), or one geom_segment() over a small
#   arrow table, so each is drawn once rather than once per data row.
# - Limits are set only via scale_*_continuous.
figure_5b <- ggplot(data = fig5_dfrm[which(fig5_dfrm$oa_coll_rate >= 0), ]) +
  geom_line(aes(x = debris, y = oa_coll_rate), size = 1) +
  stat_contour(aes(x = debris, y = satellites, z = sat_null, group = ..level..),
               size = 1, breaks = c(0), color = "#313695") +
  stat_contour(aes(x = debris, y = satellites, z = deb_null, group = ..level..),
               size = 1, breaks = c(0), color = "#A50026") +
  annotate("text", label = "Dn", x = 2.75 * d_scale_factor, y = 3.75 * s_scale_factor,
           color = "#A50026", size = 7) +
  annotate("text", label = "Sn", x = 3.9 * d_scale_factor, y = 1.2 * s_scale_factor,
           color = "#313695", size = 7) +
  annotate("text", label = "Em", x = 3.1 * d_scale_factor, y = 0.9 * s_scale_factor,
           color = "black", size = 7) +
  annotate("text", label = "1s", x = 0.5 * d_scale_factor, y = 2.1 * s_scale_factor,
           color = "black", size = 7) +
  annotate("text", label = "Oa", x = 1.15 * d_scale_factor, y = 2.37 * s_scale_factor,
           color = "black", size = 7) +
  geom_point(aes(x = oass_surf.debs, y = oass_surf.sats), size = 1) +
  annotate("segment", x = 0, y = 0, xend = Inf, yend = 0, size = 0.15) +
  annotate("segment", x = 0, y = 0, xend = 0, yend = Inf, size = 0.15) +
  annotate("point", x = oass[2], y = oass[1], size = 5) +
  # Arrows at the four corners (0.5, 0.5), (3.5, 0.5), (3.5, 3.75) and
  # (0.5, 3.75). Each corner has one horizontal and one vertical arrow of
  # length 0.15, all in model units scaled onto the axes.
  geom_segment(
    data = data.frame(
      x0 = c(0.5, 0.5, 3.5, 3.5, 3.5, 3.5, 0.5, 0.5),
      y0 = c(0.5, 0.5, 0.5, 0.5, 3.75, 3.75, 3.75, 3.75),
      dx = c(0.15, 0, -0.15, 0, -0.15, 0, 0.15, 0),
      dy = c(0, 0.15, 0, 0.15, 0, -0.15, 0, -0.15)
    ),
    aes(x = x0 * d_scale_factor, y = y0 * s_scale_factor,
        xend = (x0 + dx) * d_scale_factor, yend = (y0 + dy) * s_scale_factor),
    arrow = arrow(length = unit(0.25, "cm")),
    inherit.aes = FALSE
  ) +
  ylab("Satellites") + xlab("Debris") +
  scale_x_continuous(limits = c(0, 4 * d_scale_factor)) +
  scale_y_continuous(limits = c(0, 4 * s_scale_factor)) +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15),
        panel.grid.minor = element_blank())

# Time-series panels (B-E of the saved figure) ----------------------------------
# Build one panel from the first 25 rows of fig5_time_dfrm, with three dashed
# series:
# - `column` (the un-suffixed sj1 trajectory) in viridis(9)[5];
# - "<column>_s2" in viridis(9)[3];
# - "<column>_cd1" in viridis(9)[7].
# Two thin segments from the origin act as axis lines.
make_fig4_time_panel <- function(column, y_label) {
  ggplot(data = fig5_time_dfrm[1:25, ], aes(x = time)) +
    geom_line(aes(y = .data[[column]]),
              linetype = "dashed", color = viridis(9)[5], size = 1) +
    geom_line(aes(y = .data[[paste0(column, "_s2")]]),
              linetype = "dashed", color = viridis(9)[3], size = 1) +
    geom_line(aes(y = .data[[paste0(column, "_cd1")]]),
              linetype = "dashed", color = viridis(9)[7], size = 1) +
    geom_segment(x = 0, y = 0, xend = 35, yend = 0, size = 0.15) +
    geom_segment(x = 0, y = 0, xend = 0, yend = 5 * d_scale_factor, size = 0.15) +
    ylab(y_label) + xlab("Time") +
    theme_bw() + theme(text = element_text(family = "Arial", size = 15))
}

figure_4time_a <- make_fig4_time_panel("launches", "Launches")
figure_4time_b <- make_fig4_time_panel("risk", "Collision risk")
figure_4time_c <- make_fig4_time_panel("satellites", "Satellites")
figure_4time_d <- make_fig4_time_panel("debris", "Debris")

# Phase plane for panel A of the saved figure. It shows:
# - launch paths 1 (d1), 2 (s2) and 3 (sj1);
# - the dotted one-step-jump ray;
# - the oa_coll_rate curve and the Sn/Dn zero contours (semi-transparent);
# - curve labels;
# - the open-access steady state (oass, already in calibrated units);
# - short arrows at four corners.
# Constant-position layers use annotate(), or one geom_segment() over a small
# arrow table, so each is drawn once rather than once per data row.
# The axis lines were previously added twice; they are now drawn once and end
# at Inf (the panel edge), which the scale limits do not censor.
big_phase_panel <- ggplot(data = fig5_dfrm[which(fig5_dfrm$oa_coll_rate >= 0), ]) +
  annotate("segment", x = 0, y = 0, xend = Inf, yend = 0, size = 0.15) +
  annotate("segment", x = 0, y = 0, xend = 0, yend = Inf, size = 0.15) +
  annotate("segment",
           x = fig5_dfrm$oass_surf.debs[ray_start_idx],
           y = fig5_dfrm$oass_surf.sats[ray_start_idx],
           xend = oass[2], yend = oass[1],
           size = 1, linetype = "dotted") +
  geom_path(aes(x = d1.lp_debs, y = d1.lp_sats),
            linetype = "dashed", color = viridis(9)[7], size = 1.5) +
  geom_path(aes(x = s2.lp_debs, y = s2.lp_sats),
            linetype = "dashed", color = viridis(9)[3], size = 1.5) +
  geom_path(aes(x = sj1.lp_debs, y = sj1.lp_sats),
            linetype = "dashed", color = viridis(9)[5], size = 1.5) +
  annotate(geom = "text", x = fig5_dfrm$sj1.lp_debs[1], y = fig5_dfrm$sj1.lp_sats[1],
           label = "3", hjust = 0, vjust = 1.5, size = 7, color = viridis(9)[5]) +
  annotate(geom = "text", x = fig5_dfrm$s2.lp_debs[3], y = fig5_dfrm$s2.lp_sats[3],
           label = "2", hjust = -1, vjust = 1.5, size = 7, color = viridis(9)[3]) +
  annotate(geom = "text", x = 1 * d_scale_factor, y = 0.6 * s_scale_factor,
           label = "1", hjust = 0, vjust = 1.5, size = 7, color = viridis(9)[7]) +
  ylab("Satellites") + xlab("Debris") +
  geom_line(aes(x = debris, y = oa_coll_rate), size = 1, alpha = 0.5) +
  stat_contour(aes(x = debris, y = satellites, z = sat_null, group = ..level..),
               size = 1, breaks = c(0), color = "#313695", alpha = 0.5) +
  stat_contour(aes(x = debris, y = satellites, z = deb_null, group = ..level..),
               size = 1, breaks = c(0), color = "#A50026", alpha = 0.5) +
  annotate("text", label = "Dn", x = 2.75 * d_scale_factor, y = 3.75 * s_scale_factor,
           color = "#A50026", size = 7) +
  annotate("text", label = "Sn", x = 3.9 * d_scale_factor, y = 1.2 * s_scale_factor,
           color = "#313695", size = 7) +
  annotate("text", label = "Em", x = 3.1 * d_scale_factor, y = 0.9 * s_scale_factor,
           color = "black", size = 7) +
  annotate("text", label = "1s", x = 0.5 * d_scale_factor, y = 2.1 * s_scale_factor,
           color = "black", size = 7) +
  annotate("text", label = "Oa", x = 1.15 * d_scale_factor, y = 2.37 * s_scale_factor,
           color = "black", size = 7) +
  annotate("point", x = oass[2], y = oass[1], size = 5) +
  # Arrows at the four corners (0.5, 0.5), (3.5, 0.5), (3.5, 3.75) and
  # (0.5, 3.75). Each corner has one horizontal and one vertical arrow of
  # length 0.15, all in model units scaled onto the axes.
  geom_segment(
    data = data.frame(
      x0 = c(0.5, 0.5, 3.5, 3.5, 3.5, 3.5, 0.5, 0.5),
      y0 = c(0.5, 0.5, 0.5, 0.5, 3.75, 3.75, 3.75, 3.75),
      dx = c(0.15, 0, -0.15, 0, -0.15, 0, 0.15, 0),
      dy = c(0, 0.15, 0, 0.15, 0, -0.15, 0, -0.15)
    ),
    aes(x = x0 * d_scale_factor, y = y0 * s_scale_factor,
        xend = (x0 + dx) * d_scale_factor, yend = (y0 + dy) * s_scale_factor),
    arrow = arrow(length = unit(0.25, "cm")),
    inherit.aes = FALSE
  ) +
  scale_x_continuous(labels = comma, limits = c(0, 4 * d_scale_factor)) +
  scale_y_continuous(labels = comma, limits = c(0, 4 * s_scale_factor)) +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15),
        panel.grid.minor = element_blank())

# Assemble and save figure 4: phase plane (A) beside the stacked time-series
# panels (B-E).
# Panel A applies coord_flip(), so debris (mapped to x) is drawn on the vertical
# axis. align/axis/rel_widths are omitted there because they have no effect on a
# single-plot grid.
big_phase_panel_label <- plot_grid(
  big_phase_panel + coord_flip(),
  labels = c("A"), label_size = 15
)

# label_size = 15 is applied here, to the grids that actually carry labels.
# Previously it was passed to the outer grid, which has no labels, so it had no
# effect.
fig4_time_panel <- plot_grid(
  figure_4time_a, figure_4time_b, figure_4time_c, figure_4time_d,
  align = "v", labels = c("B", "C", "D", "E"), nrow = 4, label_size = 15
)

# axis = "none" states what the previous axis = "1" amounted to: "1" is not one
# of cowplot's l/r/t/b margin codes, so no margin alignment was applied. Use
# axis = "tb" to try aligning top/bottom margins instead.
fig4_combined <- plot_grid(
  big_phase_panel_label, fig4_time_panel,
  align = "h", axis = "none", nrow = 1, ncol = 2, rel_widths = c(1/2, 1/2)
)

ggsave(
  fig4_combined,
  filename = "../../images/figure-4--oa-phase-trajectories.png",
  width = 16, height = 8, units = "in", dpi = 300
)